{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial - Comet ml experiments monitoring\n",
    "\n",
    "In this notebook, we will see how to monitor your experiments using the integrated **comet_ml** callbacks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the library\n",
    "%pip install pythae"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train your Pythae model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.datasets as datasets\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=None)\n",
    "\n",
    "train_dataset = mnist_trainset.data[:-10000].reshape(-1, 1, 28, 28) / 255.\n",
    "eval_dataset = mnist_trainset.data[-10000:].reshape(-1, 1, 28, 28) / 255."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.models import BetaVAE, BetaVAEConfig\n",
    "from pythae.trainers import BaseTrainerConfig\n",
    "from pythae.pipelines.training import TrainingPipeline\n",
    "from pythae.models.nn.benchmarks.mnist import Encoder_ResNet_VAE_MNIST, Decoder_ResNet_AE_MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_config = BaseTrainerConfig(\n",
    "    output_dir='my_model',\n",
    "    learning_rate=1e-4,\n",
    "    per_device_train_batch_size=64,\n",
    "    per_device_eval_batch_size=64,\n",
    "    num_epochs=10, # Change this to train the model a bit more,\n",
    "    steps_predict=3\n",
    ")\n",
    "\n",
    "\n",
    "model_config = BetaVAEConfig(\n",
    "    input_dim=(1, 28, 28),\n",
    "    latent_dim=16,\n",
    "    beta=2.\n",
    "\n",
    ")\n",
    "\n",
    "model = BetaVAE(\n",
    "    model_config=model_config,\n",
    "    encoder=Encoder_ResNet_VAE_MNIST(model_config), \n",
    "    decoder=Decoder_ResNet_AE_MNIST(model_config) \n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Before launching the pipeline, you will need to build your `CometCallback`\n",
    "\n",
    "To be able to access this feature you will need:\n",
    "- a valid comet_ml account\n",
    "- the `comet_ml` package installed in your virtual env. You can install it by running `pip install comet_ml`\n",
    "- your `api_key`, which you pass when setting up the `CometCallback`. Note that you may need to run `comet init --api-key` to set up your api-key locally and be able to synchronize your offline runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Before being allowed to monitor your experiments you may need to run the following\n",
    "# !pip install comet_ml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create your callback\n",
    "from pythae.trainers.training_callbacks import CometCallback\n",
    "\n",
    "callbacks = [] # the TrainingPipeline expects a list of callbacks\n",
    "\n",
    "comet_cb = CometCallback() # Build the callback \n",
    "\n",
    "# Set up the callback\n",
    "comet_cb.setup(\n",
    "    training_config=training_config, # training config\n",
    "    model_config=model_config, # model config\n",
    "    api_key=\"your_comet_api_key\", # specify your comet api-key\n",
    "    project_name=\"your_comet_project\", # specify your comet project\n",
    "    #offline_run=True, # run in offline mode\n",
    "    #offline_directory='my_offline_runs' # set the directory to store the offline runs\n",
    ")\n",
    "\n",
    "callbacks.append(comet_cb) # Add it to the callbacks list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline = TrainingPipeline(\n",
    "    training_config=training_config,\n",
    "    model=model\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline(\n",
    "    train_data=train_dataset,\n",
    "    eval_data=eval_dataset,\n",
    "    callbacks=callbacks # pass the callbacks to the TrainingPipeline and you are done!\n",
    ")\n",
    "# You can log in to https://comet.com/your_comet_username/your_comet_project to monitor your training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternatively, you can view the Comet UI directly in the Jupyter notebook\n",
    "import comet_ml\n",
    "\n",
    "experiment = comet_ml.get_global_experiment()\n",
    "experiment.display()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "3efa06c4da850a09a4898b773c7e91b0da3286dbbffa369a8099a14a8fa43098"
  },
  "kernelspec": {
   "display_name": "Python 3.8.11 64-bit ('pythae_dev': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
